import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

# CIFAR-10 is bundled with Keras; load_data() hands back a (train, test) pair
# of (images, labels) tuples.
dataset = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = dataset.load_data()

# Display names indexed by integer label: class_names[label] is the name
# shown for that label when plotting.
class_names = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]

# Show the first 25 *training* images on a single 5x5 figure, each captioned
# with its class name, as a quick sanity check of the raw data.
plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i + 1)
    # Hide ticks and grid: these are images, not plots.
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(x_train[i])
    # Each label entry is itself indexable (a length-1 array here), so [0]
    # extracts the integer class index used to look up the display name.
    plt.xlabel(class_names[y_train[i][0]])
plt.show()

# Scale pixels to [0, 1]. Casting to float32 first avoids the float64 arrays
# that dividing the raw integer pixels by 255.0 would otherwise produce
# (twice the memory, and Keras computes in float32 anyway).
x_train = x_train.astype(np.float32) / 255.0
x_test = x_test.astype(np.float32) / 255.0
# NOTE: no tf.expand_dims here. CIFAR-10 images are already RGB with a
# trailing channel axis, matching the model's Input([32, 32, 3]); adding
# another axis (as one would for grayscale MNIST) yields 5-D tensors the
# model cannot accept.

# One-hot encode the integer labels for categorical_crossentropy.
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10).astype(np.float32)
y_test = tf.keras.utils.to_categorical(y_test, num_classes=10).astype(np.float32)

batch_size = 256
# Shuffle individual examples *before* batching. Shuffling after .batch()
# only reorders whole batches, so every batch would contain the same
# images in every epoch.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(batch_size * 10)
    .batch(batch_size)
)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)

# CNN classifier: two 64-filter convs (with BatchNorm between), max-pool,
# a 128-filter conv, max-pool, then a 256-unit dense head with dropout and
# a 10-way softmax output.
input_xs = tf.keras.Input([32, 32, 3])
conv = tf.keras.layers.Conv2D(64, 3, padding="SAME", activation=tf.nn.relu)(input_xs)
conv = tf.keras.layers.BatchNormalization()(conv)
conv = tf.keras.layers.Conv2D(64, 3, padding="SAME", activation=tf.nn.relu)(conv)
# NOTE(review): MaxPool2D defaults to pool_size=2 and padding="valid", so a
# stride of 1 shrinks each spatial dim by only one pixel (32 -> 31 -> 30)
# and the Dense layer below sees 30*30*128 inputs. Drop `strides` if real
# 2x downsampling was intended.
conv = tf.keras.layers.MaxPool2D(strides=[1, 1])(conv)
conv = tf.keras.layers.Conv2D(128, 3, padding="SAME", activation=tf.nn.relu)(conv)
conv = tf.keras.layers.MaxPool2D(strides=[1, 1])(conv)
# Dropout must sit on the path into Flatten; a Dropout whose output is not
# consumed by any later layer is silently left out of the model.
conv = tf.keras.layers.Dropout(0.2)(conv)
flat = tf.keras.layers.Flatten()(conv)
dense = tf.keras.layers.Dense(256, activation=tf.nn.relu)(flat)
dense = tf.keras.layers.Dropout(0.2)(dense)
# Softmax activation: the model outputs class probabilities, not logits.
probs = tf.keras.layers.Dense(10, activation=tf.nn.softmax)(dense)
model = tf.keras.Model(inputs=input_xs, outputs=probs)
# summary() prints the table itself and returns None, so it is not wrapped in print().
model.summary()

# The model ends in softmax, so the loss receives probabilities (Keras'
# categorical_crossentropy defaults to from_logits=False).
model.compile(
    optimizer=tf.optimizers.Adam(1e-3),
    loss=tf.losses.categorical_crossentropy,
    metrics=['accuracy'],
)
# Shuffling is done inside train_dataset via .shuffle(); Keras ignores
# fit(shuffle=...) when given a tf.data.Dataset, so that argument is omitted.
# The held-out test pipeline doubles as validation data, batched like training.
history = model.fit(train_dataset, epochs=15, validation_data=test_dataset)

# Plot per-epoch training accuracy against validation accuracy.
plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label='val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
# Show the full [0, 1] range: early-epoch accuracy may be below 0.5, and a
# 0.5 floor would clip those points off the chart.
plt.ylim([0, 1])
plt.legend(loc='lower right')
plt.show()

# Save the trained model as HDF5. The path is relative to the current working
# directory; create the parent directory first, because writing the .h5 file
# fails if "../data" does not exist. makedirs is a no-op when it already exists.
tf.io.gfile.makedirs("../data")
model.save("../data/model.h5")
# With a single 'accuracy' metric compiled, evaluate() returns [loss, accuracy].
score = model.evaluate(test_dataset)
print("last score:", score[1])
